import  os
import sys
import torch.nn as nn
import time
import numpy as np
import importlib
import torch
import torch.distributed as dist
from transformers import RobertaForSequenceClassification, RobertaTokenizer, RobertaConfig, GPT2ForSequenceClassification
from datasets import load_dataset
from tqdm import tqdm
import argparse
# import evaluate as eval_metric
import random
from peft import LoraConfig, get_peft_model
from myLoraModel.mapping import get_our_peft_model
from constants import  task_to_keys
from myLoraModel.MyLoRALayer import  MyLoRALayer

def setup(rank, world_size, port):
    """Initialise the default torch.distributed process group.

    Points the rendezvous at ``localhost:<port>`` through the ``MASTER_ADDR`` /
    ``MASTER_PORT`` environment variables, then joins the group with the NCCL
    backend as process ``rank`` out of ``world_size``.
    """
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=str(port))
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)




def cleanup():
    """Tear down the default torch.distributed process group created by ``setup``."""
    teardown = dist.destroy_process_group
    teardown()


if __name__ == "__main__":
    # Command-line options. Several of them (com_rounds, heterogeneity, lr_A, lr_B,
    # rank_max, rank_min, lamb) are overwritten further down by per-dataset presets,
    # so a CLI value only takes effect where no preset applies.
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int, default=1, help="number of processes participating in the job")
    parser.add_argument("--rank", type=int, default=0, help="rank of the current process")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu id to run")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--fp16", action='store_true', help="use fp16 (half-precision floating point)")
    parser.add_argument("--com_interval", type=int, default=1, help="The interval of communication")
    parser.add_argument("--name_or_path", type=str, default='roberta-base', help="The type of pretrained model")
    parser.add_argument("--dataset", type=str, default='sst2', help="The dataset")
    parser.add_argument("--method", type=str, default='pf2lora', help="The optimized method")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="clip parameter")
    parser.add_argument("--lr_A", type=float, default=5e-5, help="learning rate for fine-tuning lora_A")
    parser.add_argument("--lr_B", type=float, default=5e-5, help="learning rate for fine-tuning lora_B")
    parser.add_argument("--lr_scheduler_type", type=str, default="linear", help="learning schedule type")
    # NOTE(review): the meaning of nu is defined by the method module, not this script;
    # the old help text was a copy of --lr_B's. Confirm and refine the wording.
    parser.add_argument("--nu", type=float, default=5e-4, help="the nu hyperparameter for our method")
    parser.add_argument("--beta", type=float, default=0.99, help="momentum parameter for our method")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size for fine-tuning")
    parser.add_argument("--eval_batch_size", type=int, default=16, help="batch size for evaluation")
    parser.add_argument("--num_epochs", type=int, default=2, help="number of epochs for fine-tuning")
    parser.add_argument("--save_path", type=str, default='../output/', help="the path to save results")
    parser.add_argument("--port", type=int, default=12355, help="the port to communicate")
    parser.add_argument("--lamb", type=float, default=0.1, help="the balanced parameter for hetlora")
    parser.add_argument("--rank_mat", type=int, default=8, help="the rank for lora")
    parser.add_argument("--rank_max", type=int, default=15, help="the maximum lora rank for hetlora")
    parser.add_argument("--rank_min", type=int, default=5, help="the minimum lora rank for hetlora")
    parser.add_argument("--gamma", type=float, default=0.5, help="the sparsity of hetlora")
    parser.add_argument("--heterogeneity", type=float, default=0.7, help="the dissimilarity of different client")
    parser.add_argument("--inner_loops", type=int, default=10, help="the number of inner loops")
    parser.add_argument("--z_loops", type=int, default=5, help="the number of loops for updating linear system solution")
    parser.add_argument("--num_gpu", type=int, default=4, help="the number of gpu devices")
    parser.add_argument("--process_per_gpu", type=int, default=2, help="the number of processes per gpu")
    parser.add_argument("--hessian_q", type=int, default=5, help="the number of q loops")
    parser.add_argument("--clients_per_gpu", type=int, default=2, help="simulate the multiple clients on per gpu")
    parser.add_argument("--com_rounds", type=int, default=100, help="the total communication rounds")
    parser.add_argument("--maml_in_lr", type=float, default=0.01, help="the inner-loop learning rate for maml")
    args = parser.parse_args()

    setup(args.rank, args.world_size, args.port)

    # Seed every RNG the run touches and pin cuDNN to deterministic kernels so
    # repeated runs with the same --seed are comparable.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # set the hyperparameters: per-dataset presets that replace the CLI values.
    # Each entry is (dataset-wide overrides, {method name: method overrides});
    # datasets without an entry use ``fallback_presets``.
    hetlora_presets = {'lr_B': 2e-3, 'rank_max': 12, 'rank_min': 5, 'lamb': 5e-3}
    pf2lora_presets = {'lr_B': 1e-3, 'lr_A': 1e-3}
    dataset_presets = {
        'cola': (
            {'com_rounds': 50, 'heterogeneity': 0.3},
            {
                'hetlora': {'lr_B': 5e-3, 'rank_max': 12, 'rank_min': 8, 'lamb': 1e-3},
                'pf2lora': {'lr_B': 2e-3, 'lr_A': 1e-4},
            },
        ),
        'mnli': ({'com_rounds': 300}, {'hetlora': hetlora_presets, 'pf2lora': pf2lora_presets}),
        'sst2': ({'com_rounds': 100}, {'hetlora': hetlora_presets, 'pf2lora': pf2lora_presets}),
        'qqp': ({'com_rounds': 300}, {'hetlora': hetlora_presets, 'pf2lora': pf2lora_presets}),
    }
    # NOTE(review): the fallback keys its learning rates on 'ours_one_step' while every
    # named dataset uses 'pf2lora', so pf2lora on other datasets keeps the CLI learning
    # rates. Possibly a leftover from a method rename -- confirm.
    fallback_presets = (
        {'com_rounds': 100},
        {'hetlora': hetlora_presets, 'ours_one_step': pf2lora_presets},
    )
    shared_presets, method_presets = dataset_presets.get(args.dataset, fallback_presets)
    for hparam, value in {**shared_presets, **method_presets.get(args.method, {})}.items():
        setattr(args, hparam, value)

    # rank initialization for hetlora: the LoRA rank starts at rank_min for process 0
    # and grows with the process rank, scaled by (rank_max - rank_min) / world_size.
    if args.method == 'hetlora':
        rank_span = args.rank_max - args.rank_min
        args.rank_mat = args.rank_min + int(rank_span * args.rank / args.world_size)

    device = torch.device(f"cuda:{args.gpu}")
    torch.cuda.set_device(device)
    raw_dataset = load_dataset("glue", name=args.dataset)

    # NOTE(review): args.dataset is always a string here (argparse default 'sst2'),
    # so the ``is None`` fallback branches in this section never run.
    if args.dataset is not None:
        label_list = raw_dataset["train"].features["label"].names
        num_labels = len(label_list)
    else:
        label_list = raw_dataset["train"]["label"]
        num_labels = len(set(label_list))
    # The config dtype is float32 regardless of --fp16 (the previous nested
    # conditional evaluated to float32 on every path).
    # NOTE(review): confirm whether torch.float16 was intended when --fp16 is set.
    torch_dtype = torch.float32

    model_config = RobertaConfig.from_pretrained(args.name_or_path, num_labels=num_labels, torch_dtype=torch_dtype)
    model = RobertaForSequenceClassification.from_pretrained(args.name_or_path, config=model_config)
    tokenizer = RobertaTokenizer.from_pretrained(args.name_or_path)

    # Results are written under <save_path>/<name_or_path>/<dataset>/.
    save_direct = os.path.join(args.save_path, f'{args.name_or_path}/{args.dataset}')
    # Every rank executes this line concurrently; exist_ok avoids the FileExistsError
    # an exists()-then-makedirs() sequence raises when another rank creates it first.
    os.makedirs(save_direct, exist_ok=True)
    date = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))
    # NOTE(review): each rank reads its own clock, so ranks that start in different
    # minutes end up with different save paths.
    file_name = f'{args.method}_lr_B{args.lr_B}_het{args.heterogeneity}_{date}'
    args.save_path = os.path.join(save_direct, file_name)
    # lora_alpha: 8 for roberta-base, 16 for any other checkpoint.
    alpha = 8 if args.name_or_path == 'roberta-base' else 16
    peft_config = LoraConfig(
        task_type="SEQ_CLS",
        fan_in_fan_out=False,
        target_modules=['query', 'value'],
        r=args.rank_mat,
        lora_alpha=alpha,
        lora_dropout=0.1,
        # NOTE(review): presumably provided by a patched peft build -- confirm.
        use_original_init=False,
        modules_to_save=["classifier", "score"],
    )

    # pf2lora wraps the model with the project's own LoRA mapping; other methods use peft's.
    if args.method == 'pf2lora':
        model = get_our_peft_model(model, peft_config)
    else:
        model = get_peft_model(model, peft_config)

    # Manually freeze the dense layer of RoBERTa's classifier head.
    if "roberta" in args.name_or_path:
        for name, param in model.named_parameters():
            if "classifier" in name and "dense" in name:
                param.requires_grad = False

    # Per-client subsets; filled in by the partitioning step below.
    train_datasets = []
    eval_datasets = []

    if args.dataset is not None:
        sentence1_key, sentence2_key = task_to_keys[args.dataset]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [
            name for name in raw_dataset["train"].column_names if name != "label"
        ]
        if (
                "sentence1" in non_label_column_names
                and "sentence2" in non_label_column_names
        ):
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    def preprocess_function(examples):
        """Tokenize a batch of examples, padding/truncating every sequence to 512 tokens."""
        # Named ``texts`` rather than ``args`` so the argparse namespace is not shadowed.
        texts = (
            (examples[sentence1_key],)
            if sentence2_key is None
            else (examples[sentence1_key], examples[sentence2_key])
        )
        # Labels are already ClassLabel ids for GLUE, so no label mapping is done here.
        return tokenizer(*texts, padding="max_length", max_length=512, truncation=True)

    raw_dataset = raw_dataset.map(preprocess_function, batched=True)
    raw_dataset.set_format(type="torch")
    train_dataset = raw_dataset["train"]
    if args.dataset == 'mnli':
        # Only the matched validation split is used for MNLI.
        eval_dataset = raw_dataset["validation_matched"]
    else:
        eval_dataset = raw_dataset["validation"]
    dist.barrier()
    # construct the heterogeneous data.
    # Each client receives an equal-sized share of the train and eval samples. A
    # (1 - heterogeneity) fraction comes from a shuffled pool that mixes all labels
    # (iid part). The rest is a contiguous run of a pool kept in label-major order
    # (non-iid part). heterogeneity = 0 therefore yields a plain iid split; previously
    # that value skipped partitioning entirely and left the client datasets empty.
    if not 0.0 <= args.heterogeneity <= 1.0:
        raise ValueError(f'--heterogeneity must be in [0, 1], got {args.heterogeneity}')

    def _split_pools(indices_by_label):
        """Split each label's index list into an iid prefix and a non-iid suffix.

        Returns ``(iid_pool, non_iid_pool)``. ``iid_pool`` is shuffled with the
        global ``random`` RNG; ``non_iid_pool`` keeps label-major order.
        """
        iid_pool, non_iid_pool = [], []
        for label_indices in indices_by_label:
            iid_split = int((1.0 - args.heterogeneity) * len(label_indices))
            iid_pool += label_indices[:iid_split]
            non_iid_pool += label_indices[iid_split:]
        random.shuffle(iid_pool)
        return iid_pool, non_iid_pool

    def _allocate(iid_pool, non_iid_pool, dataset_size):
        """Return one unshuffled index list of ``dataset_size // world_size`` items per worker.

        Worker ``j`` takes the ``j``-th consecutive slice of each pool. Consumes no
        randomness.
        """
        partition_size = dataset_size // args.world_size
        num_iid = int((1.0 - args.heterogeneity) * partition_size)
        num_non_iid = partition_size - num_iid
        return [
            iid_pool[j * num_iid:(j + 1) * num_iid]
            + non_iid_pool[j * num_non_iid:(j + 1) * num_non_iid]
            for j in range(args.world_size)
        ]

    # Group sample positions by label.
    train_indices = [[] for _ in range(num_labels)]
    eval_indices = [[] for _ in range(num_labels)]
    for i, label in enumerate(train_dataset["label"]):
        train_indices[label].append(i)
    for i, label in enumerate(eval_dataset["label"]):
        eval_indices[label].append(i)

    # Only the training indices are shuffled within each label before splitting;
    # the eval split keeps dataset order inside each label.
    for label_indices in train_indices:
        random.shuffle(label_indices)

    train_dataset_size = sum(len(label_indices) for label_indices in train_indices)
    iid_pool, non_iid_pool = _split_pools(train_indices)
    print(f'The training dataset size: {train_dataset_size}')
    train_worker_idxs = _allocate(iid_pool, non_iid_pool, train_dataset_size)
    # train_lower holds the same samples as the worker's train split, in an independent order.
    train_lower_idxs = [list(worker_idxs) for worker_idxs in train_worker_idxs]
    for worker_idxs, lower_idxs in zip(train_worker_idxs, train_lower_idxs):
        random.shuffle(worker_idxs)
        random.shuffle(lower_idxs)

    eval_dataset_size = sum(len(label_indices) for label_indices in eval_indices)
    iid_pool, non_iid_pool = _split_pools(eval_indices)
    eval_worker_idxs = _allocate(iid_pool, non_iid_pool, eval_dataset_size)
    for worker_idxs in eval_worker_idxs:
        random.shuffle(worker_idxs)

    print(f'Rank: {args.rank}  training dataset size: {len(train_worker_idxs[args.rank])}')
    train_datasets = train_dataset.select(train_worker_idxs[args.rank])
    train_lower = train_dataset.select(train_lower_idxs[args.rank])
    eval_datasets = eval_dataset.select(eval_worker_idxs[args.rank])

    dist.barrier()
    try:
        # load method: the module named by --method (e.g. pf2lora.py) must be importable
        # from sys.path -- typically the directory containing this script -- and expose
        # ``Model``. A plain module name is required: a leading '.' (as in './name')
        # makes import_module treat it as a relative import and raise TypeError.
        method = importlib.import_module(args.method)
        model = method.Model(args, model, dist)
        # start to train; pf2lora additionally receives the lower-level training split.
        if args.method == 'pf2lora':
            model.train(train_datasets, eval_datasets, train_lower)
        else:
            model.train(train_datasets, eval_datasets)
    finally:
        # Tear the process group down even if loading or training fails.
        cleanup()

